'''
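Description: Export OCR text-line crops (gt.txt) or text-line detection annotations with reading-order previews (gt_det.txt) from a synthetic document-layout dataset.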
Author: KouXichao
Date: 2020-11-16 15:31:04
LastEditors: KouXichao
LastEditTime: 2021-04-16 14:35:03
FilePath: /LayoutGen/tools/gen_line_data.py
Github: https://github.com/kouxichao
'''
import os
import argparse
import numpy as np
import cv2
# import Levenshtein
import math
from PIL import Image, ImageDraw, ImageFont
import distorsion_generator
import shutil
import random
import sys
sys.path.append(os.path.abspath('./'))
import config
import json
from xycut import bounds_order

class SynthLayout():
    '''
    description: Reader for a synthetic document-layout dataset that exports
        OCR text-line crops (crop_image_line) and text-line detection
        annotations with reading-order previews (getinfo).
    param {str} layout_folder: dataset root. Files read below it:
        image_names.npy, image/<name>, regionInfo/<stem>/bboxes.npy and
        xml/<stem>.json.
    param {bool} warp: True for warped images; only changes how the paper
        background file name is derived in crop_image_line.
    note: constructing an instance deletes <layout_folder>/line_images/ and
        <layout_folder>/gt.txt left by previous runs and recreates an empty
        line_images/ directory.
    '''

    def __init__(self, layout_folder, warp=False):
        super(SynthLayout, self).__init__()
        self.layout_folder = layout_folder
        self.warp = warp
        # Image file names (with extension), one per layout page.
        self.images = np.load(os.path.join(layout_folder, 'image_names.npy')).tolist()
        # Per-image bbox containers, unpacked by load_image_gt as
        # [word_bboxes, line_bboxes, char_bboxes]. allow_pickle=True lets
        # np.load unpickle object arrays (which can execute code), so only
        # point this at datasets you trust.
        self.all_bboxes = [
            np.load(os.path.join(self.layout_folder,
                                 'regionInfo/' + name.split('.')[0] + '/bboxes.npy'),
                    allow_pickle=True)
            for name in self.images
        ]

        # Start from a clean output state: drop previous line crops and OCR labels.
        self.line_image_path = os.path.join(layout_folder, 'line_images/')
        if os.path.exists(self.line_image_path):
            shutil.rmtree(self.line_image_path)
        gt_path = os.path.join(self.layout_folder, 'gt.txt')
        if os.path.exists(gt_path):
            os.remove(gt_path)
        os.mkdir(self.line_image_path)

    def __len__(self):
        '''
        :return: number of images in the dataset.
        '''
        return len(self.images)

    def pt_comp(self, pt1, pt2):
        '''
        description: Partial order on 2-D points.
        return: -1 if pt1 is strictly smaller than pt2 on both axes, 1 if it is
            strictly larger on both axes, otherwise 0.
        '''
        if pt1[0] < pt2[0] and pt1[1] < pt2[1]:
            return -1
        if pt1[0] > pt2[0] and pt1[1] > pt2[1]:
            return 1
        return 0

    def isTitleText(self, imageName, cen_pt):
        '''
        description: Check whether a point lies strictly inside a region whose
            label contains 'title'.
        param imageName: image name without extension; regions are read from
            <layout_folder>/xml/<imageName>.json ("shapes" list with "label"
            and "points" per region).
        param cen_pt: (x, y) point to test.
        return: True if for some title region points[0] is strictly below and
            points[2] strictly above cen_pt on both axes (per pt_comp).
        '''
        json_path = os.path.join(self.layout_folder, 'xml/' + imageName + '.json')
        # Context manager so the annotation file is closed (the file handle leaked before).
        with open(json_path, encoding="utf-8") as f:
            region_dict = json.load(f)
        for region in region_dict["shapes"]:
            label = region["label"]
            rect = region["points"]
            # points[0] / points[2] are used as opposite box corners
            # (presumably top-left / bottom-right — confirm with the annotation tool).
            if 'title' in label and self.pt_comp(rect[0], cen_pt) == -1 \
                    and self.pt_comp(rect[2], cen_pt) == 1:
                return True
        return False

    def get_imagename(self, index):
        '''
        :param index: image index in [0, len(self.images)).
        :return: image file name (including extension).
        '''
        return self.images[index]

    def load_image_gt(self, index):
        '''
        :param index: image index in [0, len(self.images)).
        :return: (RGB image as numpy array, char bboxes, word bboxes,
            line bboxes, image name without extension).
        :raises IndexError: if index is out of range.
        :raises FileNotFoundError: if the image file cannot be read.
        '''
        # Explicit check instead of assert, which is stripped under `python -O`.
        if not 0 <= index < len(self.images):
            raise IndexError("image index {} out of range [0, {})".format(index, len(self.images)))
        img_path = os.path.join(self.layout_folder, 'image/' + self.images[index])
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError("cannot read image: {}".format(img_path))
        # OpenCV loads BGR; everything downstream works in RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        bboxes = self.all_bboxes[index]
        word_bboxes = bboxes[0]
        line_bboxes = bboxes[1]
        char_bboxes = bboxes[2]
        return image, char_bboxes, word_bboxes, line_bboxes, self.images[index].split(".")[0]

    @staticmethod
    def _text_width(draw, text, font=None):
        '''
        description: Pixel width of `text` rendered with `font`.
            ImageDraw.textsize was removed in Pillow 10, so prefer textbbox and
            fall back to textsize only on old Pillow versions without textbbox.
        '''
        if hasattr(draw, "textbbox"):
            left, _, right, _ = draw.textbbox((0, 0), text, font=font)
            return right - left
        return draw.textsize(text, font=font)[0]

    @staticmethod
    def _draw_polyline(draw, pts):
        '''
        description: Draw a thick polyline through `pts`. Pillow's line() with
            joint='curve' indexes xy[0], which fails on an empty list, and one
            point has no segment to draw, so fewer than two points is a no-op.
        '''
        if len(pts) > 1:
            draw.line(pts, fill='black', width=5, joint='curve')

    def getinfo(self, index):
        '''
        description: Export text-line detection annotations for one image and
            save two visualisations of the line reading order.
        param index: image index in [0, len(self)).
        side effects:
            - <layout_folder>/<imageName>pre_order.jpg: line boxes numbered in
              stored order, joined by a polyline through their centres.
            - <layout_folder>/<imageName>read_order.jpg: the same, numbered in
              the order returned by xycut.bounds_order.
            - appends "image/<imageName>.jpg\\t<JSON list of
              {transcription, points}>" to <layout_folder>/gt_det.txt.
        '''
        def dict_ann(transcription, points):
            return {'transcription': transcription, 'points': points}

        image, _, word_bboxes, line_bboxes, imageName = self.load_image_gt(index)
        pre_img = Image.fromarray(image)
        pre_draw = ImageDraw.Draw(pre_img)
        read_img = Image.fromarray(image.copy())  # independent canvas
        read_draw = ImageDraw.Draw(read_img)
        ft = ImageFont.truetype("./font/cn_title/simhei.ttf", 35)

        anns = []      # per-line {'transcription', 'points'} for gt_det.txt
        bounds = []    # (group key, (minx, miny, maxx, maxy)) input for bounds_order
        centers = []   # box centres in stored order, for the polyline
        for i in range(line_bboxes.shape[0]):
            # Line text = concatenation of its words' text (last field of each word entry).
            text = "".join(wb[-1] for wb in word_bboxes[i]).strip()

            # Flatten the polygon and regroup it into [x, y] pairs.
            pts = np.array(line_bboxes[i]).reshape((-1)).tolist()
            points = [[pts[j], pts[j + 1]] for j in range(0, len(pts) - 1, 2)]
            anns.append(dict_ann(text, points))

            # Axis-aligned bounds used by the XY-cut reading-order algorithm.
            pts_arr = np.array(points)
            minx, miny = pts_arr.min(axis=0)
            maxx, maxy = pts_arr.max(axis=0)
            bounds.append((('regions', 'TEXT', str(i)), (minx, miny, maxx, maxy)))
            centers.append(((minx + maxx) // 2, (miny + maxy) // 2))

            pre_draw.rectangle([(minx, miny), (maxx, maxy)], outline='black', width=3)
            label = str(i)
            # Measure with the same font used to draw, so the label sits left of the box.
            pos = (minx - self._text_width(pre_draw, label, ft) - 10, miny)
            pre_draw.text(pos, label, font=ft, fill='blue')
        self._draw_polyline(pre_draw, centers)
        pre_img.save(os.path.join(self.layout_folder, imageName) + 'pre_order.jpg')

        # Redraw the same boxes, numbered in XY-cut reading order.
        read_centers = []
        for rank, g in enumerate(bounds_order(bounds)):
            line_points = anns[int(g[-1])]['points']
            p1 = tuple(line_points[0])
            p3 = tuple(line_points[2])
            read_draw.rectangle([p1, p3], outline='black', width=3)
            label = str(rank)
            pos = (p1[0] - self._text_width(read_draw, label, ft) - 10, p1[1])
            read_draw.text(pos, label, font=ft, fill='blue')
            read_centers.append(((p1[0] + p3[0]) // 2, (p1[1] + p3[1]) // 2))
        self._draw_polyline(read_draw, read_centers)
        read_img.save(os.path.join(self.layout_folder, imageName) + 'read_order.jpg')
        print('-------------{}------------'.format(imageName))

        anns_file = os.path.join(self.layout_folder, "gt_det.txt")
        anns_str = json.dumps(anns, ensure_ascii=False)
        with open(anns_file, 'a', encoding='utf-8') as f:
            f.write(os.path.join("image/", imageName + '.jpg') + '\t' + anns_str + '\n')

    def crop_image_line(self, index, skewing_angle=5, distorsion_type=3, distorsion_orientation=0, margin_lr=10, margin_tb=3):
        '''
        description:
            Crop every text line of one image, save the crops to line_images/
            and append "line_images/<name>\\t<text>" rows to gt.txt.
            Only axis-aligned (straight) lines are supported. TODO: crop curved lines.
        params:
            index(int): document image index.
            skewing_angle(int): currently unused (skew augmentation not implemented).
            distorsion_type(int): 0: none, 1: sine wave, 2: cosine wave, other: random.
                NOTE: currently ignored — distortion is forced off below.
            distorsion_orientation(int): 0: vertical (up/down), 1: horizontal
                (left/right), 2: both.
            margin_lr(int): padding in pixels along the text direction.
            margin_tb(int): padding in pixels across the text direction.
        return:
            imagesOfLine(list): undistorted line crops (PIL images).
            textOfLine(list): line texts, aligned with imagesOfLine.
        '''
        image, _, word_bboxes, line_bboxes, imageName = self.load_image_gt(index)
        h, w = image.shape[0], image.shape[1]
        # Paper background name: last '_' token of the image name (second to last when warped).
        bg_imgName = imageName.split('_')[-2 if self.warp else -1] + ".jpg"
        with Image.open(os.path.join("./content/paperbackground", bg_imgName)) as bg_src:
            bg_img = bg_src.resize((config.picWidth, config.picHeight))

        imagesOfLine = []
        textOfLine = []
        print("imageName: ", imageName)

        # NOTE: distortion is deliberately disabled; the distorsion_type argument is overridden.
        distorsion_type = 0
        warp_vertical = distorsion_orientation in (0, 2)
        warp_horizontal = distorsion_orientation in (1, 2)

        with open(os.path.join(self.layout_folder, "gt.txt"), 'a', encoding='utf-8') as gt_file:
            for i in range(line_bboxes.shape[0]):
                arr = np.array(line_bboxes[i]).reshape((-1))
                # Slice ends are exclusive, so upper bounds clip to h / w (not h-1 / w-1)
                # to keep the last row/column for lines touching the image border.
                if arr[2] > arr[0]:
                    # Horizontal line: rows arr[1]..arr[5], columns arr[0]..arr[2].
                    is_vertical = False
                    t = np.clip(arr[1] - margin_tb, 0, h - 1)
                    b = np.clip(arr[5] + margin_tb, 0, h)
                    l = np.clip(arr[0] - margin_lr, 0, w - 1)
                    r = np.clip(arr[2] + margin_lr, 0, w)
                else:
                    # Vertical line: rows arr[1]..arr[3], columns arr[4]..arr[0].
                    is_vertical = True
                    t = np.clip(arr[1] - margin_lr, 0, h - 1)
                    b = np.clip(arr[3] + margin_lr, 0, h)
                    l = np.clip(arr[4] - margin_tb, 0, w - 1)
                    r = np.clip(arr[0] + margin_tb, 0, w)

                cropped = Image.fromarray(image[t:b, l:r, :])
                bg_crop = bg_img.crop((l, t, r, b))
                if is_vertical:
                    # Rotate vertical lines so the saved crop reads horizontally.
                    cropped = cropped.rotate(90, expand=1)
                    bg_crop = bg_crop.rotate(90, expand=1)

                # Line text = concatenation of its words' text (last field of each word entry).
                text = "".join(wb[-1] for wb in word_bboxes[i]).strip()

                # Optional distortion (currently always type 0, i.e. none).
                if distorsion_type == 0:
                    distorted_img = cropped
                else:
                    if distorsion_type == 1:
                        distort = distorsion_generator.sin
                    elif distorsion_type == 2:
                        distort = distorsion_generator.cos
                    else:
                        distort = distorsion_generator.random
                    distorted_img, _ = distort(
                        cropped,
                        cropped,  # TODO: pass a real text mask
                        vertical=warp_vertical,
                        horizontal=warp_horizontal,
                        bg_img=bg_crop,
                    )
                distorted_img = np.array(distorted_img.convert("RGB"))

                save_name = str(index) + "_" + str(i) + "line.jpg"
                # Pixel data is RGB (load_image_gt converted it) but cv2.imwrite expects BGR;
                # without this the saved crops had red and blue swapped.
                cv2.imwrite(os.path.join(self.line_image_path, save_name),
                            cv2.cvtColor(distorted_img, cv2.COLOR_RGB2BGR))
                gt_file.write(os.path.join("line_images/", save_name) + '\t' + text + '\n')
                imagesOfLine.append(cropped)
                textOfLine.append(text)
        return imagesOfLine, textOfLine

if __name__ == '__main__':
    # Command-line entry point. --data_type selects what is generated:
    #   1 -> OCR text-line crops in line_images/ + gt.txt (SynthLayout.crop_image_line)
    #   2 -> detection annotations in gt_det.txt + reading-order previews (SynthLayout.getinfo)
    parser = argparse.ArgumentParser(description="!!!!!!!!!!!!generate line image from Layout data!!!!!!!!!!!!!")
    parser.add_argument("-id", "--image_dir", default="./POD_val_paper/", type=str, help="Layout data folder")
    parser.add_argument("-dt", "--data_type", default=2, type=int, choices=(1, 2),
                        help="which data to get: 1 = OCR line images (gt.txt), 2 = detection annotations (gt_det.txt)")
    parser.add_argument("-sk", "--skewing_angle", default=5, type=int,
                        help="Max skew angle in degrees for line crops (currently unused)")
    parser.add_argument("-dist", "--distorsion_type", default=1, type=int,
                        help="Line distortion: 0 none, 1 sine, 2 cosine, 3 random (currently ignored, distortion is disabled)")
    parser.add_argument("-do", "--distorsion_orientation", default=0, type=int,
                        help="Distortion orientation: 0 vertical, 1 horizontal, 2 both")
    parser.add_argument("-ml", "--margin_lr", default=10, type=int,
                        help="Padding in pixels along the text direction of each line crop")
    parser.add_argument("-mt", "--margin_tb", default=3, type=int,
                        help="Padding in pixels across the text direction of each line crop")
    parser.add_argument("-w", "--warp", default=False, action='store_true', help="warp image")

    opt = parser.parse_args()

    # SynthLayout.__init__ already removes gt.txt and recreates an empty line_images/,
    # so only the detection annotation file still needs clearing here.
    dataloader = SynthLayout(opt.image_dir, warp=opt.warp)
    det_gt_path = os.path.join(dataloader.layout_folder, "gt_det.txt")
    if os.path.exists(det_gt_path):
        os.remove(det_gt_path)

    print("\033[33m Start !!!!! \033[0m")
    for i in range(len(dataloader)):
        if opt.data_type == 1:
            # data for OCR recognition
            dataloader.crop_image_line(i,
                                       skewing_angle=opt.skewing_angle,
                                       distorsion_type=opt.distorsion_type,
                                       distorsion_orientation=opt.distorsion_orientation,
                                       margin_lr=opt.margin_lr,
                                       margin_tb=opt.margin_tb)
        elif opt.data_type == 2:
            # data for text detection
            dataloader.getinfo(i)

    print("\033[33m End !!!!! \033[0m")
